/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "c/ddk/graph/attr_value.h"
#include "graph/attr_value.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "resource_manager.h"
#include "framework/infra/log/log.h"

/**
 * Wraps `value` in a ge::AttrValue, hands a shared_ptr to it to the resource
 * manager via StoreSrcPtr, and returns a raw handle to the new attribute.
 *
 * @param resMgr opaque handle to a ResourceManager; must not be nullptr.
 * @param value  payload passed to ge::AttrValue::CreateFrom.
 * @return non-owning handle to the attribute, or nullptr if resMgr is nullptr.
 *         NOTE(review): the handle presumably stays valid only while resMgr
 *         retains the stored shared_ptr — confirm against ResourceManager.
 */
template <typename T>
AttrHandle CreateAttr(ResMgrHandle resMgr, T value)
{
    if (resMgr == nullptr) {
        FMK_LOGE("resMgr is nullptr");
        return nullptr;
    }
    // Build the shared AttrValue straight from the temporary returned by
    // CreateFrom instead of copying it out of a named local.
    std::shared_ptr<ge::AttrValue> attrPtr = std::make_shared<ge::AttrValue>(ge::AttrValue::CreateFrom(value));

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    resMgrPtr->StoreSrcPtr(std::shared_ptr<ge::Base>(attrPtr));

    return attrPtr.get();
}

/**
 * Creates an int64 attribute and registers it with resMgr.
 *
 * @return the attribute handle, or nullptr if resMgr is nullptr.
 */
AttrHandle HIAI_IR_CreateInt64Attr(ResMgrHandle resMgr, int64_t value)
{
    const AttrHandle int64Attr = CreateAttr<int64_t>(resMgr, value);
    return int64Attr;
}

/**
 * Creates a bool attribute and registers it with resMgr.
 *
 * @return the attribute handle, or nullptr if resMgr is nullptr.
 */
AttrHandle HIAI_IR_CreateBoolAttr(ResMgrHandle resMgr, bool value)
{
    const AttrHandle boolAttr = CreateAttr<bool>(resMgr, value);
    return boolAttr;
}

/**
 * Creates a float attribute and registers it with resMgr.
 *
 * @return the attribute handle, or nullptr if resMgr is nullptr.
 */
AttrHandle HIAI_IR_CreateFloatAttr(ResMgrHandle resMgr, float value)
{
    const AttrHandle floatAttr = CreateAttr<float>(resMgr, value);
    return floatAttr;
}

/**
 * Creates a string attribute from a NUL-terminated C string and registers it
 * with resMgr.
 *
 * @param resMgr resource manager handle; must not be nullptr.
 * @param value  NUL-terminated string to copy; must not be nullptr.
 * @return the attribute handle, or nullptr if resMgr or value is nullptr.
 */
AttrHandle HIAI_IR_CreateStringAttr(ResMgrHandle resMgr, const char* value)
{
    // Constructing std::string from a null pointer is undefined behavior.
    if (value == nullptr) {
        FMK_LOGE("value is nullptr");
        return nullptr;
    }
    return CreateAttr<std::string>(resMgr, std::string(value));
}

/**
 * Creates a tensor attribute that shares the ge::Tensor previously registered
 * with resMgr under the handle `value`.
 *
 * @param resMgr resource manager handle; must not be nullptr.
 * @param value  tensor handle looked up via ResourceManager::GetSrcPtr.
 * @return the attribute handle, or nullptr if resMgr/value is nullptr or the
 *         handle does not resolve to a ge::Tensor.
 */
AttrHandle HIAI_IR_CreateTensorAttr(ResMgrHandle resMgr, TensorHandle value)
{
    if (resMgr == nullptr) {
        FMK_LOGE("resMgr is nullptr");
        return nullptr;
    }
    if (value == nullptr) {
        FMK_LOGE("value is nullptr");
        return nullptr;
    }
    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    BasePtr tensorBasePtr = resMgrPtr->GetSrcPtr(value);
    std::shared_ptr<ge::Tensor> tensorPtr = std::dynamic_pointer_cast<ge::Tensor>(tensorBasePtr);
    // Either the lookup failed or the stored object is not a tensor; wrapping a
    // null tensor would silently produce a useless attribute.
    if (tensorPtr == nullptr) {
        FMK_LOGE("failed to get tensor from value");
        return nullptr;
    }

    return CreateAttr<std::shared_ptr<ge::Tensor>>(resMgr, tensorPtr);
}

/**
 * Creates an int64 list attribute from the first `num` elements of `value`
 * and registers it with resMgr.
 *
 * @param resMgr resource manager handle; must not be nullptr.
 * @param value  array of at least `num` elements; may be nullptr only if num is 0.
 * @param num    number of elements to copy.
 * @return the attribute handle, or nullptr on invalid arguments.
 */
AttrHandle HIAI_IR_CreateInt64LstAttr(ResMgrHandle resMgr, int64_t value[], uint32_t num)
{
    // A null array is only valid for an empty list; otherwise the range below is undefined.
    if (value == nullptr && num != 0) {
        FMK_LOGE("value is nullptr");
        return nullptr;
    }
    return CreateAttr<std::vector<int64_t>>(resMgr, std::vector<int64_t>(value, value + num));
}

/**
 * Creates a bool list attribute from the first `num` elements of `value`
 * and registers it with resMgr.
 *
 * @param resMgr resource manager handle; must not be nullptr.
 * @param value  array of at least `num` elements; may be nullptr only if num is 0.
 * @param num    number of elements to copy.
 * @return the attribute handle, or nullptr on invalid arguments.
 */
AttrHandle HIAI_IR_CreateBoolLstAttr(ResMgrHandle resMgr, bool value[], uint32_t num)
{
    // A null array is only valid for an empty list; otherwise the range below is undefined.
    if (value == nullptr && num != 0) {
        FMK_LOGE("value is nullptr");
        return nullptr;
    }
    return CreateAttr<std::vector<bool>>(resMgr, std::vector<bool>(value, value + num));
}

/**
 * Creates a float list attribute from the first `num` elements of `value`
 * and registers it with resMgr.
 *
 * @param resMgr resource manager handle; must not be nullptr.
 * @param value  array of at least `num` elements; may be nullptr only if num is 0.
 * @param num    number of elements to copy.
 * @return the attribute handle, or nullptr on invalid arguments.
 */
AttrHandle HIAI_IR_CreateFloatLstAttr(ResMgrHandle resMgr, float value[], uint32_t num)
{
    // A null array is only valid for an empty list; otherwise the range below is undefined.
    if (value == nullptr && num != 0) {
        FMK_LOGE("value is nullptr");
        return nullptr;
    }
    return CreateAttr<std::vector<float>>(resMgr, std::vector<float>(value, value + num));
}

/**
 * Creates a string list attribute from `size` NUL-terminated C strings and
 * registers it with resMgr.
 *
 * @param resMgr resource manager handle; must not be nullptr.
 * @param value  array of `size` non-null C strings; may be nullptr only if size is 0.
 * @param size   number of strings to copy.
 * @return the attribute handle, or nullptr on invalid arguments.
 */
AttrHandle HIAI_IR_CreateStringLstAttr(ResMgrHandle resMgr, const char *value[], size_t size)
{
    // A null array is only valid for an empty list.
    if (value == nullptr && size != 0) {
        FMK_LOGE("value is nullptr");
        return nullptr;
    }
    std::vector<std::string> val;
    val.reserve(size);
    for (size_t i = 0; i < size; i++) {
        // std::string cannot be constructed from a null pointer.
        if (value[i] == nullptr) {
            FMK_LOGE("value contains a nullptr element");
            return nullptr;
        }
        val.emplace_back(value[i]);
    }
    return CreateAttr<std::vector<std::string>>(resMgr, std::move(val));
}
